{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "iid = True\n",
    "num_users = 100\n",
    "frac = 0.1\n",
    "local_ep = 1\n",
    "\n",
    "dataset = 'mnist'\n",
    "# dataset = 'cifar10'\n",
    "# dataset = 'cifar100'\n",
    "\n",
    "shard_per_user = 10\n",
    "\n",
    "if dataset == 'mnist':\n",
    "    model = 'mlp'\n",
    "    rd_lg = 100\n",
    "    rd_fed = 800 + int(rd_lg*0.15)\n",
    "elif dataset == 'cifar10':\n",
    "    model = 'cnn'\n",
    "    rd_lg = 100\n",
    "    rd_fed = 1800 + int(rd_lg*0.04)\n",
    "elif dataset == 'cifar100':\n",
    "    model = 'cnn'\n",
    "    rd_fed = 1800\n",
    "    rd_lg = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dir = './save/{}/{}_iid{}_num{}_C{}_le{}/shard{}/'.format(\n",
    "    dataset, model, iid, num_users, frac, local_ep, shard_per_user)\n",
    "runs = os.listdir(base_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_fed = np.zeros(len(runs))\n",
    "acc_local_localtest = np.zeros(len(runs))\n",
    "acc_local_newtest_avg = np.zeros(len(runs))\n",
    "acc_local_newtest_ens = np.zeros(len(runs))\n",
    "lg_metrics = {}\n",
    "\n",
    "for idx, run in enumerate(runs):\n",
    "    # FedAvg\n",
    "    base_dir_fed = os.path.join(base_dir, \"{}/fed\".format(run))\n",
    "    results_path_fed = os.path.join(base_dir_fed, \"results.csv\")\n",
    "    df_fed = pd.read_csv(results_path_fed)\n",
    "\n",
    "    acc_fed[idx] = df_fed.loc[rd_fed - 1]['best_acc']\n",
    "    \n",
    "    # LocalOnly\n",
    "    base_dir_local = os.path.join(base_dir, \"{}/local\".format(run))\n",
    "    results_path_local = os.path.join(base_dir_local, \"results.csv\")\n",
    "    df_local = pd.read_csv(results_path_local)\n",
    "\n",
    "    acc_local_localtest[idx] = df_local.loc[0]['acc_test_local']\n",
    "    acc_local_newtest_avg[idx] = df_local.loc[0]['acc_test_avg']\n",
    "    if 'acc_test_ens' in df_local.columns:\n",
    "        acc_local_newtest_ens[idx] = df_local.loc[0]['acc_test_ens']\n",
    "    else:\n",
    "        acc_local_newtest_ens[idx] = df_local.loc[0]['acc_test_ens_avg']\n",
    "    \n",
    "    # LGFed\n",
    "    base_dir_lg = os.path.join(base_dir, \"{}/lg/\".format(run))\n",
    "    lg_runs = os.listdir(base_dir_lg)\n",
    "    for lg_run in lg_runs:\n",
    "        results_path_lg = os.path.join(base_dir_lg, \"{}/results.csv\".format(lg_run))\n",
    "        df_lg = pd.read_csv(results_path_lg)\n",
    "        \n",
    "        load_fed = int(re.split('best_|.pt', lg_run)[1])\n",
    "        if load_fed not in lg_metrics.keys():\n",
    "            lg_metrics[load_fed] = {'acc_local': np.zeros(len(runs)),\n",
    "                                    'acc_avg': np.zeros(len(runs)),\n",
    "                                    'acc_ens': np.zeros(len(runs))}\n",
    "            \n",
    "        x = df_lg.loc[rd_lg]['best_acc_local']\n",
    "        lg_metrics[load_fed]['acc_local'][idx] = x\n",
    "        idx_acc_local = df_lg[df_lg['best_acc_local'] == x].index[0]\n",
    "        lg_metrics[load_fed]['acc_avg'][idx] = df_lg.loc[idx_acc_local]['acc_test_avg']\n",
    "        if 'acc_test_ens' in df_lg.columns:\n",
    "            lg_metrics[load_fed]['acc_ens'][idx] = df_lg['acc_test_ens'].values[-1]\n",
    "        else:\n",
    "            lg_metrics[load_fed]['acc_ens'][idx] = df_lg['acc_test_ens_avg'].values[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "columns = [\"Run\", \"Local Test\", \"New Test (avg)\", \"New Test (ens)\", \"FedAvg Rounds\", \"LG Rounds\"]\n",
    "results = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "localonly:\t 88.03 +- 0.37\n",
      "localonly_avg:\t 86.24 +- 0.87\n",
      "localonly_ens:\t 91.15 +- 0.27\n"
     ]
    }
   ],
   "source": [
    "str_acc_local_localtest = \"{:.2f} +- {:.2f}\".format(acc_local_localtest.mean(), acc_local_localtest.std())\n",
    "str_acc_local_newtest_avg = \"{:.2f} +- {:.2f}\".format(acc_local_newtest_avg.mean(), acc_local_newtest_avg.std())\n",
    "str_acc_local_newtest_ens = \"{:.2f} +- {:.2f}\".format(acc_local_newtest_ens.mean(), acc_local_newtest_ens.std())\n",
    "\n",
    "print(\"localonly:\\t\", str_acc_local_localtest)\n",
    "print(\"localonly_avg:\\t\", str_acc_local_newtest_avg)\n",
    "print(\"localonly_ens:\\t\", str_acc_local_newtest_ens)\n",
    "\n",
    "results.append([\"LocalOnly\", str_acc_local_localtest, str_acc_local_newtest_avg, str_acc_local_newtest_ens, 0, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "300\n",
      "acc_local:\t97.45 +- 0.12\n",
      "acc_avg:\t97.50 +- 0.10\n",
      "acc_ens:\t97.49 +- 0.09\n",
      "400\n",
      "acc_local:\t97.59 +- 0.08\n",
      "acc_avg:\t97.61 +- 0.08\n",
      "acc_ens:\t97.62 +- 0.08\n",
      "500\n",
      "acc_local:\t97.78 +- 0.13\n",
      "acc_avg:\t97.82 +- 0.14\n",
      "acc_ens:\t97.82 +- 0.13\n",
      "600\n",
      "acc_local:\t97.84 +- 0.10\n",
      "acc_avg:\t97.86 +- 0.08\n",
      "acc_ens:\t97.87 +- 0.10\n",
      "700\n",
      "acc_local:\t97.85 +- 0.09\n",
      "acc_avg:\t97.88 +- 0.09\n",
      "acc_ens:\t97.89 +- 0.09\n",
      "800\n",
      "acc_local:\t97.91 +- 0.10\n",
      "acc_avg:\t97.93 +- 0.07\n",
      "acc_ens:\t97.93 +- 0.07\n"
     ]
    }
   ],
   "source": [
    "for lg_run in sorted(lg_metrics.keys()):\n",
    "    x = [\"LG-FedAvg\"]\n",
    "    print(lg_run)\n",
    "    for array in ['acc_local', 'acc_avg', 'acc_ens']:\n",
    "        mean = lg_metrics[lg_run][array].mean()\n",
    "        std = lg_metrics[lg_run][array].std()\n",
    "        str_acc = \"{:.2f} +- {:.2f}\".format(mean, std)\n",
    "        print(\"{}:\\t{}\".format(array, str_acc))\n",
    "        \n",
    "        x.append(str_acc)\n",
    "    x.append(lg_run)\n",
    "    x.append(rd_lg)\n",
    "    results.append(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fed:\t 97.93 +- 0.08\n"
     ]
    }
   ],
   "source": [
    "str_acc_fed = \"{:.2f} +- {:.2f}\".format(acc_fed.mean(), acc_fed.std())\n",
    "print(\"fed:\\t\", str_acc_fed)\n",
    "results.append([\"FedAvg\", str_acc_fed, str_acc_fed, str_acc_fed, rd_fed, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Local Test</th>\n",
       "      <th>New Test (avg)</th>\n",
       "      <th>New Test (ens)</th>\n",
       "      <th>FedAvg Rounds</th>\n",
       "      <th>LG Rounds</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Run</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>LocalOnly</th>\n",
       "      <td>88.03 +- 0.37</td>\n",
       "      <td>86.24 +- 0.87</td>\n",
       "      <td>91.15 +- 0.27</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LG-FedAvg</th>\n",
       "      <td>97.45 +- 0.12</td>\n",
       "      <td>97.50 +- 0.10</td>\n",
       "      <td>97.49 +- 0.09</td>\n",
       "      <td>300</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LG-FedAvg</th>\n",
       "      <td>97.59 +- 0.08</td>\n",
       "      <td>97.61 +- 0.08</td>\n",
       "      <td>97.62 +- 0.08</td>\n",
       "      <td>400</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LG-FedAvg</th>\n",
       "      <td>97.78 +- 0.13</td>\n",
       "      <td>97.82 +- 0.14</td>\n",
       "      <td>97.82 +- 0.13</td>\n",
       "      <td>500</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LG-FedAvg</th>\n",
       "      <td>97.84 +- 0.10</td>\n",
       "      <td>97.86 +- 0.08</td>\n",
       "      <td>97.87 +- 0.10</td>\n",
       "      <td>600</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LG-FedAvg</th>\n",
       "      <td>97.85 +- 0.09</td>\n",
       "      <td>97.88 +- 0.09</td>\n",
       "      <td>97.89 +- 0.09</td>\n",
       "      <td>700</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>LG-FedAvg</th>\n",
       "      <td>97.91 +- 0.10</td>\n",
       "      <td>97.93 +- 0.07</td>\n",
       "      <td>97.93 +- 0.07</td>\n",
       "      <td>800</td>\n",
       "      <td>100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FedAvg</th>\n",
       "      <td>97.93 +- 0.08</td>\n",
       "      <td>97.93 +- 0.08</td>\n",
       "      <td>97.93 +- 0.08</td>\n",
       "      <td>815</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              Local Test New Test (avg) New Test (ens)  FedAvg Rounds  \\\n",
       "Run                                                                     \n",
       "LocalOnly  88.03 +- 0.37  86.24 +- 0.87  91.15 +- 0.27              0   \n",
       "LG-FedAvg  97.45 +- 0.12  97.50 +- 0.10  97.49 +- 0.09            300   \n",
       "LG-FedAvg  97.59 +- 0.08  97.61 +- 0.08  97.62 +- 0.08            400   \n",
       "LG-FedAvg  97.78 +- 0.13  97.82 +- 0.14  97.82 +- 0.13            500   \n",
       "LG-FedAvg  97.84 +- 0.10  97.86 +- 0.08  97.87 +- 0.10            600   \n",
       "LG-FedAvg  97.85 +- 0.09  97.88 +- 0.09  97.89 +- 0.09            700   \n",
       "LG-FedAvg  97.91 +- 0.10  97.93 +- 0.07  97.93 +- 0.07            800   \n",
       "FedAvg     97.93 +- 0.08  97.93 +- 0.08  97.93 +- 0.08            815   \n",
       "\n",
       "           LG Rounds  \n",
       "Run                   \n",
       "LocalOnly          0  \n",
       "LG-FedAvg        100  \n",
       "LG-FedAvg        100  \n",
       "LG-FedAvg        100  \n",
       "LG-FedAvg        100  \n",
       "LG-FedAvg        100  \n",
       "LG-FedAvg        100  \n",
       "FedAvg             0  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(results, columns=columns).set_index(\"Run\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
